# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base ProcessingLayer and a subclass that uses Combiners."""

import abc

from keras.engine import data_adapter
from keras.engine.base_layer import Layer
from keras.utils import version_utils
import tensorflow.compat.v2 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


keras_kpl_gauge = tf.__internal__.monitoring.BoolGauge(
    '/tensorflow/api/keras/layers/preprocessing',
    'keras preprocessing layers usage', 'method')


@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
  """Base class for Preprocessing Layers.

  **Don't use this class directly: it's an abstract base class!** You may
  be looking for one of the many built-in
  [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
  instead.

  Preprocessing layers are layers whose state is computed before model
  training starts and is not updated during training. Most preprocessing
  layers implement an `adapt()` method to compute that state from sample data.

  The `PreprocessingLayer` class is the base class you would subclass to
  implement your own preprocessing layers.
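
  For example, here is a minimal sketch of a custom subclass (`ClipToMax` and
  its `max_value` state are illustrative, not a built-in layer):

  ```python
  class ClipToMax(PreprocessingLayer):
    # Clips inputs to the maximum value seen during `adapt`.

    def build(self, input_shape):
      self.max_value = self.add_weight(
          name='max_value', initializer='zeros', trainable=False)

    def update_state(self, data):
      # Accumulate the running maximum over all adapted batches.
      self.max_value.assign(
          tf.maximum(self.max_value, tf.reduce_max(tf.cast(data, 'float32'))))

    def reset_state(self):
      self.max_value.assign(0.)

    def call(self, inputs):
      return tf.minimum(inputs, self.max_value)
  ```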
  """
  _must_restore_from_config = True

  def __init__(self, **kwargs):
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._is_compiled = False
    self._is_adapted = False

    # Wrap `reset_state` so that calling it also sets `is_adapted` to `False`.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    self._adapt_function = None

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  @doc_controls.do_not_generate_docs
  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.
    """
    raise NotImplementedError

  @doc_controls.do_not_generate_docs
  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the statistics of the preprocessing layer."""
    raise NotImplementedError

  @doc_controls.do_not_generate_docs
  def finalize_state(self):
    """Finalize the statistics for the preprocessing layer.

    This method is called at the end of `adapt` or after restoring a serialized
    preprocessing layer's state. This method handles any one-time operations
    that should occur on the layer's state before `Layer.__call__`.
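
    For example, a layer that accumulates hypothetical `total` and `count`
    weights during `adapt` might convert them into a `mean` here (a sketch,
    not a built-in implementation):

    ```python
    def finalize_state(self):
      self.mean.assign(self.total / tf.cast(self.count, self.total.dtype))
    ```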
    """
    pass

  @doc_controls.do_not_generate_docs
  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.
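
    For example, here is a minimal override sketch that skips the
    `steps_per_execution` handling and caching of the default implementation,
    and updates state eagerly one batch at a time:

    ```python
    def make_adapt_function(self):
      def adapt_fn(iterator):
        data = next(iterator)
        self._adapt_maybe_build(data)
        self.update_state(data)
      return adapt_fn
    ```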

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      adapt_fn = adapt_step
    else:

      def adapt_fn(iterator):
        for _ in tf.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = tf.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Arguments:
      run_eagerly: Bool. Defaults to `False`. If `True`, this layer's `adapt`
        logic will not be wrapped in a `tf.function`. Recommended to leave this
        as `None` unless the layer cannot be run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
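
    Example of configuring `adapt` to run many batches per `tf.function` call
    (a usage sketch; `dataset` is a hypothetical batched `tf.data.Dataset`):

    ```python
    layer = tf.keras.layers.Normalization(axis=None)
    layer.compile(steps_per_execution=16)
    layer.adapt(dataset)  # `adapt` respects the existing compile settings.
    ```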
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None):
    """Fits the state of the preprocessing layer to the data being passed.

    After calling `adapt` on a layer, a preprocessing layer's state will not
    update during training. In order to make preprocessing layers efficient in
    any distribution context, they are kept constant with respect to any
    compiled `tf.Graph`s that call the layer. This does not affect layer use
    when adapting each layer only once; however, if you adapt a layer multiple
    times you will need to take care to re-compile any compiled functions as
    follows:

     * If you are adding a preprocessing layer to a `keras.Model`, you need to
       call `model.compile` after each subsequent call to `adapt`.
     * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
       you should call `map` again on the input `tf.data.Dataset` after each
       `adapt`.
     * If you are using a `tf.function` directly which calls a preprocessing
       layer, you need to call `tf.function` again on your callable after
       each subsequent call to `adapt`.

    `tf.keras.Model` example with multiple adapts:

    >>> layer = tf.keras.layers.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> model = tf.keras.Sequential(layer)
    >>> model.predict([0, 1, 2])
    array([-1.,  0.,  1.], dtype=float32)
    >>> layer.adapt([-1, 1])
    >>> model.compile() # This is needed to re-compile model.predict!
    >>> model.predict([0, 1, 2])
    array([0., 1., 2.], dtype=float32)

    `tf.data.Dataset` example with multiple adapts:

    >>> layer = tf.keras.layers.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> input_ds = tf.data.Dataset.range(3)
    >>> normalized_ds = input_ds.map(layer)
    >>> list(normalized_ds.as_numpy_iterator())
    [array([-1.], dtype=float32),
     array([0.], dtype=float32),
     array([1.], dtype=float32)]
    >>> layer.adapt([-1, 1])
    >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
    >>> list(normalized_ds.as_numpy_iterator())
    [array([0.], dtype=float32),
     array([1.], dtype=float32),
     array([2.], dtype=float32)]

    `adapt()` is meant only as a single-machine utility to compute layer state.
    To analyze a dataset that cannot fit on a single machine, see
    [TensorFlow Transform](https://www.tensorflow.org/tfx/transform/get_started)
    for a multi-machine, map-reduce solution.

    Arguments:
        data: The data to train on. It can be passed either as a
          `tf.data.Dataset` or as a NumPy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples) to process.
            If `data` is a `tf.data.Dataset` and `steps` is `None`, `adapt`
            will run until the input dataset is exhausted. When passing an
            infinitely repeating dataset, you must specify the `steps`
            argument (see the sketch below this list). This argument is not
            supported with array inputs.
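
    For example, when adapting on an infinitely repeating dataset you must
    pass `steps` (a usage sketch; names are illustrative):

    ```python
    ds = tf.data.Dataset.range(100).map(lambda x: tf.cast(x, 'float32'))
    ds = ds.batch(8).repeat()  # Repeats forever; `steps` is required.
    layer = tf.keras.layers.Normalization(axis=None)
    layer.adapt(ds, steps=10)  # Update state on exactly 10 batches.
    ```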
    """
    _disallow_inside_tf_function('adapt')
    if not version_utils.should_use_v2():
      raise RuntimeError('`adapt` is only supported in TensorFlow 2.')  # pylint: disable=g-doc-exception
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    if self.built:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls `reset_state` and sets `adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @tf.__internal__.tracking.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
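    # Note: `no_automatic_dependency_tracking` above keeps this variable out of
    # the layer's tracked (checkpointed) state; it only configures `adapt`.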
    self._steps_per_execution = tf.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    if not self.built:
      try:
        # If this is a NumPy array or tensor, we can get the shape from
        # `.shape`. If not, an AttributeError will be raised.
        data_shape = data.shape
        data_shape_nones = (None,) * len(data_shape)
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True


def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`."""
  if tf.inside_function():
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)
